{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SphereFace Network\n",
    "[SphereFace: Deep Hypersphere Embedding for Face Recognition](http://wyliu.com/papers/LiuCVPR17v3.pdf)\n",
    "\n",
    "Written with reference to [tensorflow sphereface](https://github.com/hujun100/tensorflow-sphereface).\n",
    "\n",
    "### Structure: four convolution units\n",
    "- Conv1.x\n",
    "- Conv2.x\n",
    "- Conv3.x\n",
    "- Conv4.x\n",
    "\n",
    "### Split the network into different parts\n",
    "- Conv with strides\n",
    "- Conv with residual units"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Caffe implementation\n",
    "Use [Netscope](http://ethereon.github.io/netscope/#/editor) to visualize the Caffe network defined in the prototxt file.\n",
    "\n",
    "**Note**: Each Conv layer is followed by a PReLU layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "import numpy as np\n",
    "tf.reset_default_graph()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.prelu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prelu(x, name = 'prelu'):\n",
    "    \"\"\"Parametric ReLU: identity for positive inputs, learned slope for negative ones.\n",
    "\n",
    "    Args:\n",
    "        x: input tensor; one slope value is created per entry of its last dimension.\n",
    "        name: variable scope under which the `alpha` slope variable is created.\n",
    "\n",
    "    Returns:\n",
    "        A tensor equal to relu(x) + alpha * min(x, 0).\n",
    "\n",
    "    Note:\n",
    "        Reads the notebook-level `l2_regularizer`, which must exist before the first call.\n",
    "    \"\"\"\n",
    "    channels = x.get_shape()[-1]\n",
    "    with tf.variable_scope(name):\n",
    "        slope = tf.get_variable(\n",
    "            'alpha',\n",
    "            channels,\n",
    "            initializer=tf.constant_initializer(0.25),\n",
    "            regularizer=l2_regularizer,\n",
    "            dtype=tf.float32)\n",
    "    positive_part = tf.nn.relu(x)\n",
    "    # (x - |x|) / 2 equals x where x < 0 and 0 elsewhere, i.e. the negative part of x.\n",
    "    negative_part = (x - abs(x)) * 0.5\n",
    "    return positive_part + tf.multiply(slope, negative_part)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.conv with strides"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def first_conv(input, num_output, name):\n",
    "    \"\"\"Stride-2 3x3 'same' convolution followed by a PReLU activation.\n",
    "\n",
    "    Args:\n",
    "        input: tensor passed to `tf.layers.conv2d`.\n",
    "        num_output: number of convolution filters.\n",
    "        name: variable scope of the layer; the PReLU is nested in a scope of the same name.\n",
    "\n",
    "    Returns:\n",
    "        The activated output tensor.\n",
    "\n",
    "    Note:\n",
    "        Relies on the notebook-level `xavier` and `l2_regularizer` objects.\n",
    "    \"\"\"\n",
    "    with tf.variable_scope(name):\n",
    "        conv_out = tf.layers.conv2d(\n",
    "            input,\n",
    "            num_output,\n",
    "            kernel_size=[3, 3],\n",
    "            strides=(2, 2),\n",
    "            padding='same',\n",
    "            kernel_initializer=xavier,\n",
    "            bias_initializer=tf.zeros_initializer(),\n",
    "            kernel_regularizer=l2_regularizer,\n",
    "            bias_regularizer=l2_regularizer)\n",
    "        return prelu(conv_out, name=name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.conv with residual units"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def block(input, name, num_output):\n",
    "    \"\"\"Residual unit: two bias-free 3x3 conv + PReLU stages added back onto the input.\n",
    "\n",
    "    Args:\n",
    "        input: tensor entering the unit; it is added element-wise to the branch output.\n",
    "        name: variable scope wrapping every variable of the unit.\n",
    "        num_output: number of filters used by both convolutions.\n",
    "\n",
    "    Returns:\n",
    "        `input` plus the output of the two-layer branch.\n",
    "\n",
    "    Note:\n",
    "        Relies on the notebook-level `l2_regularizer` object.\n",
    "    \"\"\"\n",
    "    def conv_prelu(tensor, prelu_scope):\n",
    "        # Stride-1 'same' convolution without bias, then PReLU in its own scope.\n",
    "        conv_out = tf.layers.conv2d(\n",
    "            tensor,\n",
    "            num_output,\n",
    "            kernel_size=[3, 3],\n",
    "            strides=[1, 1],\n",
    "            padding='same',\n",
    "            kernel_initializer=tf.random_normal_initializer(stddev=0.01),\n",
    "            use_bias=False,\n",
    "            kernel_regularizer=l2_regularizer)\n",
    "        return prelu(conv_out, name=prelu_scope)\n",
    "\n",
    "    with tf.variable_scope(name):\n",
    "        # NOTE(review): the PReLU scopes are the literal strings 'name1' / 'name2',\n",
    "        # not derived from the `name` argument. Kept as-is so variable names stay unchanged.\n",
    "        branch = conv_prelu(input, 'name'+ '1')\n",
    "        branch = conv_prelu(branch, 'name'+ '2')\n",
    "        return tf.add(input, branch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.infer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared by every layer-building helper in this notebook.\n",
    "l2_regularizer= tf.contrib.layers.l2_regularizer(1.0)\n",
    "xavier = tf.contrib.layers.xavier_initializer_conv2d()\n",
    "\n",
    "def get_shape(tensor):\n",
    "    \"\"\"Return the dimensions of `tensor` as a list, preferring static values.\n",
    "\n",
    "    Each dimension known at graph-construction time is returned as a Python int;\n",
    "    unknown ones fall back to the matching scalar tensor from `tf.shape(tensor)`.\n",
    "    \"\"\"\n",
    "    known = tensor.shape.as_list()\n",
    "    runtime = tf.unstack(tf.shape(tensor))\n",
    "    resolved = []\n",
    "    for static_dim, dynamic_dim in zip(known, runtime):\n",
    "        resolved.append(dynamic_dim if static_dim is None else static_dim)\n",
    "    return resolved\n",
    "\n",
    "def infer(input,embedding_size=512):\n",
    "    \"\"\"Build the network: four strided-conv stages with residual blocks, then a dense layer.\n",
    "\n",
    "    The stages use 64, 128, 256 and 512 filters and contain 1, 2, 4 and 1 residual\n",
    "    blocks respectively. The final feature map is flattened and projected by a\n",
    "    fully connected layer.\n",
    "\n",
    "    Args:\n",
    "        input: tensor fed to the first convolution.\n",
    "        embedding_size: number of units of the final dense layer.\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the dense layer.\n",
    "\n",
    "    Side effects:\n",
    "        Prints the resolved shape of the last feature map.\n",
    "    \"\"\"\n",
    "    # (outer scope, strided-conv scope, filters, residual-block scopes)\n",
    "    stages = (\n",
    "        ('conv1_', 'conv1', 64, ('conv1_23',)),\n",
    "        ('conv2_', 'conv2', 128, ('conv2_23', 'conv2_45')),\n",
    "        ('conv3_', 'conv3', 256, ('conv3_23', 'conv3_45', 'conv3_67', 'conv3_89')),\n",
    "        ('conv4_', 'conv4', 512, ('conv4_23',)),\n",
    "    )\n",
    "    network = input\n",
    "    for stage_scope, conv_scope, channels, block_scopes in stages:\n",
    "        with tf.variable_scope(stage_scope):\n",
    "            network = first_conv(network, channels, name = conv_scope)\n",
    "            for block_scope in block_scopes:\n",
    "                network = block(network, block_scope, channels)\n",
    "    with tf.variable_scope('feature'):\n",
    "        dims = get_shape(network)\n",
    "        print(dims)\n",
    "        # Collapse everything but the first axis before the fully connected layer.\n",
    "        flattened = tf.reshape(network, [dims[0], np.prod(dims[1:])])\n",
    "        feature = tf.layers.dense(flattened, embedding_size, kernel_regularizer = l2_regularizer, kernel_initializer = xavier)\n",
    "    return feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 7, 6, 512]\n",
      "(1, 512)\n"
     ]
    }
   ],
   "source": [
    "# Rebuild from an empty default graph so this cell can be re-run.\n",
    "tf.reset_default_graph()\n",
    "# A single random 112x96 3-channel input is used to exercise graph construction.\n",
    "image = tf.random_normal([1,112,96,3])\n",
    "feature = infer(image)\n",
    "# print() with one argument behaves the same on Python 2 and Python 3.\n",
    "print(feature.get_shape())\n",
    "# Dump the graph for TensorBoard; close the writer so the event file is flushed to disk.\n",
    "graph_writer = tf.summary.FileWriter('sphereface_network', tf.get_default_graph())\n",
    "graph_writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1)\n"
     ]
    }
   ],
   "source": [
    "# Toy objective: drive a single dense output towards 1 (no real labels involved).\n",
    "pred = tf.layers.dense(feature,1)\n",
    "# print() with one argument behaves the same on Python 2 and Python 3.\n",
    "print(pred.get_shape())\n",
    "loss = tf.nn.l2_loss(pred-1)\n",
    "# `optimizer` holds the training op returned by minimize(), not the optimizer object.\n",
    "optimizer = tf.train.GradientDescentOptimizer(0.0001).minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.384229\n",
      "0.277823\n",
      "0.0985079\n",
      "0.126074\n",
      "0.0673081\n",
      "0.0297425\n",
      "0.00743761\n",
      "0.0312544\n",
      "0.00901861\n",
      "0.00035515\n",
      "0.00304774\n",
      "0.000781628\n",
      "0.000343543\n",
      "3.64557e-06\n",
      "0.00109801\n",
      "0.00045575\n",
      "8.61884e-05\n",
      "1.42084e-05\n",
      "1.65151e-05\n",
      "0.00298974\n",
      "0.000394176\n",
      "0.000196685\n",
      "0.00237238\n",
      "0.00044498\n",
      "0.000510066\n"
     ]
    }
   ],
   "source": [
    "# Run 500 gradient-descent steps on the toy loss and report it every 20th step.\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    # range() and print() are used so the loop runs on both Python 2 and Python 3.\n",
    "    for i in range(500):\n",
    "        loss_np,_ = sess.run([loss,optimizer])\n",
    "        if i % 20 ==0:\n",
    "            print(loss_np)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
